'''Reproduces Paper Sec. 4.3, Supplement Sec. 5: full-waveform inversion (inverse Helmholtz problem).
'''

# Enable import from parent package
import sys
import os
sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )

import dataio, utils, training, loss_functions, modules
from torch.utils.data import DataLoader
import torch
import configargparse
from scipy.io import loadmat

# Command-line / config-file options. Any option may also be supplied through
# the config file passed with -c/--config_filepath.
p = configargparse.ArgumentParser()
p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')

p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
p.add_argument('--experiment_name', type=str, required=True,
               help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')

# General training options
p.add_argument('--batch_size', type=int, default=32)
p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
p.add_argument('--num_epochs', type=int, default=80000,
               help='Number of epochs to train for.')

# Both intervals are counted in training iterations, not wall-clock time:
# they are forwarded as epochs_til_checkpoint / steps_til_summary to training.train.
p.add_argument('--epochs_til_ckpt', type=int, default=1000,
               help='Number of epochs between checkpoint saves.')
p.add_argument('--steps_til_summary', type=int, default=100,
               help='Number of training steps between tensorboard summaries.')
p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid'],
               help='Type of model to evaluate, default is sine.')
p.add_argument('--data', type=str, default='./fwi/data_cylinder_5.mat', required=False,
               help='Data file with the source/rec coordinates and data.')
# NOTE(review): --checkpoint_path is not read anywhere in this script (model
# loading uses --load_model). Kept so existing config files still parse.
p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
# Double quotes are needed here: 'don''t ...' is implicit concatenation in
# Python and would render as "dont".
p.add_argument('--pretrain', default=False, action='store_true', help="don't solve for velocity.")
p.add_argument('--load_model', type=str, default=None, required=False,
               help='Load pretrained model from checkpoint.')

opt = p.parse_args()

# Load the source/receiver geometry and the recorded receiver data that were
# generated by the principled (reference) solver for FWI. The .mat file must
# contain the keys 'source', 'receivers' and 'rec_val'.
data = loadmat(opt.data)
source_coords = data['source']
rec_coords = data['receivers']
rec_val = data['rec_val']

dataset = dataio.InverseHelmholtz(source_coords, rec_coords, rec_val, sidelength=115, pretrain=opt.pretrain)
dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)

# Define the model: 2D input coordinates, and 2 output channels per source plus
# one extra channel (presumably a complex wavefield per source plus the
# velocity estimate -- confirm against loss_functions.helmholtz_pml).
N_src = source_coords.shape[0]  # assumes one source per row of 'source'
model = modules.SingleBVPNet(in_features=2, out_features=2 * N_src + 1, type=opt.model, final_layer_factor=1.)

if opt.load_model is not None:
    # Map to CPU first: without map_location, torch.load restores tensors to the
    # device they were saved from. That fails if the checkpoint came from a GPU
    # index that is not present on this machine. The model is moved to the GPU
    # right below, so the end result is unchanged.
    state_dict = torch.load(opt.load_model, map_location='cpu')
    model.load_state_dict(state_dict)

model.cuda()

# Define the loss
loss_fn = loss_functions.helmholtz_pml
summary_fn = utils.write_helmholtz_summary

# Summaries and checkpoints are written under <logging_root>/<experiment_name>.
root_path = os.path.join(opt.logging_root, opt.experiment_name)

training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
               steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
               model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad)
